# GBA Song Editor
#
# Copyright 2010-2011 Karl A. Knechtel.
#
# Support for MIDI -> M4A (GBA native format) conversion.
#
# Licensed under the Generic Non-Commercial Copyleft Software License,
# Version 1.1 (hereafter "Licence"). You may not use this file except
# in the ways outlined in the Licence, which you should have received
# along with this file.
#
# Unless required by applicable law or agreed to in writing, software 
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
# implied. See the License for the specific language governing
# permissions and limitations under the License.


from cStringIO import StringIO
from util import PEBKAC, Chunk, as_pointer


# Some constants related to the two encoding schemes.
# Wait lengths in GBA ticks: the wait command byte for GBA_TIME_CODES[i]
# is 0x81 + i (see ticks_to_bytes). Every length up to 24 is available,
# beyond that only selected lengths up to 96.
GBA_TIME_CODES = list(range(1, 25)) + [
	28, 30, 32, 36, 40, 42, 44, 48, 52, 54, 56, 60,
	64, 66, 68, 72, 76, 78, 80, 84, 88, 90, 92, 96,
]
# GBA ticks per quarter note; MIDI times are rescaled to this resolution.
GBA_TICKS_PER_BEAT = 24
# One-byte GBA track commands; each is written ahead of its argument bytes.
GBA_OFF = chr(0xCE)
GBA_ON = chr(0xCF)
GBA_INSTRUMENT = chr(0xBD)
GBA_VOLUME = chr(0xBE)
GBA_BPM = chr(0xBB)
# Instrument number that GBA instrument maps conventionally use for percussion.
GBA_PERCUSSION_INSTRUMENT = 0x7F

# MIDI meta-event types we interpret, and the 0-based percussion channel
# (channel 10 in 1-based numbering).
MIDI_END_OF_TRACK = 0x2F
MIDI_BPM = 0x51
MIDI_PERCUSSION_CHANNEL = 9


def ticks_to_bytes(ticks):
	"""Encode a delay of `ticks` GBA ticks as a string of wait command bytes.

	The delay is covered greedily: the longest entry of GBA_TIME_CODES that
	still fits is emitted (as byte 0x81 + its index) until nothing remains.
	A zero delay produces an empty string.
	"""
	assert ticks >= 0
	waits = []
	remaining = ticks
	while remaining:
		longest = next(code for code in reversed(GBA_TIME_CODES) if code <= remaining)
		waits.append(chr(0x81 + GBA_TIME_CODES.index(longest)))
		remaining -= longest
	return ''.join(waits)


def clip_filter(events, begin, end):
	"""Yield `events` in sorted order, shifted onto the window [begin, end).

	Times become relative to `begin`. Events earlier than `begin` are kept but
	pinned to time 0, and everything at or after `end` is dropped. A final
	(end - begin, None, None) sentinel is always yielded so that fix_events()
	flushes the last real events.
	"""
	for when, kind, payload in sorted(events):
		if when >= end:
			break
		shifted = when - begin
		yield (shifted if shifted > 0 else 0, kind, payload)
	# A bogus sentinel event for fix_events().
	yield (end - begin, None, None)


def fix_events(events):
	"""Normalize a time-sorted event stream so key on/off events stay consistent.

	Events are buffered per timestamp and released once a later timestamp is
	seen.
	- Within one timestamp, only the last GBA_ON/GBA_OFF event for a given key
	  survives.
	- Every surviving note event is preceded by a key-off for that key.
	- A key-off is dropped when the key is not currently sounding.
	- Other events pass through in order.

	Events at the final timestamp are never released. This is why
	clip_filter() appends a sentinel event.

	events -- iterable of (time, event_type, data) tuples sorted by time; for
		GBA_ON/GBA_OFF events data[0] is the key (pitch).
	Yields (time, event_type, data) tuples.
	"""
	sounding = set() # keys whose key-on has been emitted and not yet turned off
	current_time = -1
	event_queue = [] # (event_type, data) pairs pending at current_time
	for gba_time, event_type, data in events:
		if gba_time > current_time:
			for queued_type, queued_data in event_queue:
				if queued_type == GBA_OFF:
					if queued_data[0] not in sounding:
						continue # note is already off; suppress this event
					sounding.discard(queued_data[0])
				elif queued_type == GBA_ON:
					sounding.add(queued_data[0])
				yield (current_time, queued_type, queued_data)
			# The queue only holds events for current_time. It must be emptied
			# once flushed, otherwise its contents are emitted again (with a
			# later timestamp) at every subsequent flush.
			event_queue = []
			current_time = gba_time

		# The queue may contain an off event, possibly followed by an on,
		# for each note.
		if event_type in (GBA_ON, GBA_OFF):
			key = data[0]
			# Cancel any simultaneous key-on or key-off events for this key
			# and override them appropriately.
			event_queue = [
				x for x in event_queue
				if not (x[0] in (GBA_ON, GBA_OFF) and x[1][0] == key)
			]
			event_queue.append((GBA_OFF, (key,)))
			if event_type == GBA_ON:
				event_queue.append((GBA_ON, data))
		else:
			event_queue.append((event_type, data))


class MIDI_data(object):
	"""Collects parsed MIDI events and converts them into GBA track data.

	Events are stored per MIDI channel as (gba_time, gba_event, data) tuples,
	with times rescaled from MIDI ticks to GBA_TICKS_PER_BEAT ticks per
	quarter note. MIDI features that are not converted are counted and
	reported by warnings().
	"""

	def __init__(self, ticks_per_quarter, begin, end):
		"""Create an empty container.

		ticks_per_quarter -- MIDI time division from the file header.
		begin -- GBA tick at which conversion starts (None means 0).
		end -- GBA tick at which conversion stops; None means "use the latest
			end-of-track marker", resolved in converted().
		"""
		self.channels = {} # mapping of channel number to event list
		self.ticks_per_quarter = ticks_per_quarter
		# BPM-change messages will get temporarily inserted into every GBA track
		# so that they can be used to do time calculations; but they will only
		# be reflected in the converted output for the first track.
		self.tempo = []
		self.ignored = {} # mapping of warning message to occurrence count
		self.begin = 0 if begin is None else begin
		self.end = end
		self.auto_end = None # latest end-of-track time seen, in GBA ticks


	def _warn(self, message):
		"""Count one occurrence of an ignored MIDI feature."""
		self.ignored[message] = self.ignored.get(message, 0) + 1


	def _process_meta(self, gba_time, channel_id, data):
		"""Handle a system (0xF_) event; return True only for end-of-track.

		Only meta events (status 0xFF, i.e. channel_id 0xF) are interpreted.
		End-of-track updates auto_end, and tempo changes are appended to
		self.tempo. Anything else is counted as ignored.
		Raises PEBKAC for a malformed end-of-track or tempo event.
		"""
		if channel_id != 0xF:
			self._warn('non-metadata system message')
			return False

		if data[0] == MIDI_END_OF_TRACK: # end of track marker
			if len(data) != 2 or data[1] != 0:
				raise PEBKAC("MIDI format error (bad end-of-track code)")
			if self.auto_end is None or gba_time > self.auto_end:
				self.auto_end = gba_time
			return True

		if data[0] == MIDI_BPM:
			if len(data) != 5 or data[1] != 3:
				raise PEBKAC("MIDI format error (bad BPM code)")
			microseconds_per_quarter = (data[2] << 16) | (data[3] << 8) | data[4]
			if microseconds_per_quarter == 0:
				# Would otherwise be a ZeroDivisionError below.
				raise PEBKAC("MIDI format error (bad BPM code)")
			# Convert microseconds per quarter note into quarter notes per minute.
			# Floor division keeps the Python 2 integer result on Python 3 too.
			bpm = 60000000 // microseconds_per_quarter
			# The Sappy format expects to store half the BPM for the "tempo" value.
			self.tempo.append((gba_time, GBA_BPM, (bpm // 2,)))
		else:
			self._warn('unimportant metadata')

		return False


	def add(self, midi_time, event, channel_id, data):
		"""Record one MIDI event; return True only for an end-of-track marker.

		midi_time -- absolute event time in MIDI ticks.
		event -- high nibble of the status byte.
		channel_id -- low nibble of the status byte.
		data -- tuple of event data as returned by read_info().
		"""
		# Rescale to GBA ticks; floor division matches Python 2's integer `/`.
		gba_time = int(midi_time * GBA_TICKS_PER_BEAT // self.ticks_per_quarter)

		if event == 0xF: return self._process_meta(gba_time, channel_id, data)

		channel = self.channels.setdefault(channel_id, [])
		if event == 0x8: # note off
			if data[1] != 0:
				self._warn('nonzero velocity for note-off')
			channel.append((gba_time, GBA_OFF, (data[0],)))
		elif event == 0x9: # note on (velocity 0 is treated as note off)
			if data[1] == 0:
				channel.append((gba_time, GBA_OFF, (data[0],)))
			else:
				channel.append((gba_time, GBA_ON, data))
		elif event == 0xA: # aftertouch
			self._warn('aftertouch')
		elif event == 0xB:
			if data[0] == 7: # volume
				channel.append((gba_time, GBA_VOLUME, (data[1],)))
			else:
				self._warn('control change') # TODO: implement common controls
		elif event == 0xC: # program change
			# HAX: GBA instrument mapping tends to put the percussion on
			# instrument 127; by convention, MIDI channel 10 houses percussion
			# regardless of the instrument it specifies - note that channel
			# numbers are 1-based.
			channel.append((
				gba_time, GBA_INSTRUMENT,
				(GBA_PERCUSSION_INSTRUMENT,) if channel_id == MIDI_PERCUSSION_CHANNEL else data
			))
		elif event == 0xD: # channel after-touch
			self._warn('channel after-touch')
		elif event == 0xE:
			self._warn('pitch wheel change')

		return False


	def warnings(self):
		"""Return one human-readable line per kind of ignored MIDI feature."""
		return [
			"Ignored %s [%d time(s)]." % (message, count)
			for message, count in self.ignored.items()
		]


	def converted(self, loop_ticks):
		"""Convert every channel into GBA track bytes.

		loop_ticks -- loop start in GBA ticks measured from `begin` (the same
			timeline _convert_track uses), or None for no loop.
		Returns {channel number: track byte string}. Tempo events are emitted
		only on the lowest-numbered channel.
		Raises PEBKAC if the song length is unknown or the loop point lies
		beyond the clipped end of the song.
		"""
		if self.end is None: self.end = self.auto_end
		if self.end is None: # still?!
			# should be impossible... track ends are checked for elsewhere
			raise PEBKAC("MIDI format error (could not determine song length)")
		# None means "no loop" and must not be compared. The clipped song
		# ends at end - begin.
		if loop_ticks is not None and loop_ticks > self.end - self.begin:
			raise PEBKAC("Loop point beyond end of track")

		first_channel = min(self.channels) if self.channels else None
		return dict(
			(k, self._convert_track(v + self.tempo, k == first_channel, loop_ticks))
			for (k, v) in self.channels.items()
		)


	def _convert_track(self, track, first, loop_ticks):
		"""Encode one channel's events as GBA track bytes.

		track -- (gba_time, gba_event, data) tuples, including tempo events.
		first -- True for the single track that should emit tempo changes.
		loop_ticks -- loop start on the clipped timeline (0 == begin), or None.
		Returns the encoded track, ending in 0xB1 and zero-padded to a multiple
		of 4 bytes.
		"""
		result = StringIO()
		# No key shift; full volume; center pan.
		# Is this data really needed?
		result.write('\xbc\x00\xbe\x7f\xbf\x40')
		current_ticks = 0
		loop_bytes = None # offset of the loop target within this track's bytes
		current_bpm = None
		for gba_time, event, data in fix_events(clip_filter(track, self.begin, self.end)):
			# Generate wait time, if any; ticks_to_bytes yields '' for zero.
			delay_ticks = gba_time - current_ticks
			current_ticks = gba_time
			if loop_ticks is not None and loop_bytes is None and current_ticks >= loop_ticks:
				# Split the wait so the loop target lands exactly on loop_ticks.
				overflow = current_ticks - loop_ticks
				result.write(ticks_to_bytes(delay_ticks - overflow))
				loop_bytes = result.tell()
				result.write(ticks_to_bytes(overflow))
			else:
				result.write(ticks_to_bytes(delay_ticks))

			# Tempo events are only emitted on the first track, and only when
			# they actually change the tempo.
			if event == GBA_BPM:
				if data == current_bpm or not first: continue
				current_bpm = data
			result.write(event + ''.join(chr(b) for b in data))

		# Add loop if it was requested and the track actually has data beyond the loop point.
		if loop_bytes is not None:
			# We must first add a delay until the common end time.
			result.write(ticks_to_bytes(self.end - self.begin - current_ticks))
			result.write('\xb2')
			# Loop target as 4 little-endian bytes (each byte masked to 0-255).
			# NOTE(review): loop_bytes is an offset from the start of this
			# track's bytes; nothing here converts it to an absolute address -
			# confirm the caller/Chunk relocates it.
			result.write(as_little_endian(loop_bytes, 4))

		# Terminate the track.
		result.write('\xb1')
		result.write('\x00' * (-result.tell() % 4)) # pad to a 4-byte boundary.
		return result.getvalue()


def expect(source, expected, message):
	"""Consume len(expected) characters from `source`; raise PEBKAC(message) on mismatch."""
	actual = source.read(len(expected))
	if actual == expected:
		return
	raise PEBKAC(message)


def byte(source):
	"""Read one byte from `source` and return its integer value (0-255).

	Raises PEBKAC when the stream is exhausted. This means the MIDI data is
	truncated, and ord('') would otherwise fail with an opaque TypeError.
	"""
	char = source.read(1)
	if not char:
		raise PEBKAC("MIDI format error (unexpected end of data)")
	return ord(char)


def encoded_value(source):
	"""Decode a MIDI variable-length quantity from `source`.

	Each byte contributes its low 7 bits, most significant group first; a set
	high bit means another byte follows.
	"""
	total = 0
	while True:
		octet = byte(source)
		total = (total << 7) | (octet & 0x7F)
		if not octet & 0x80:
			return total


def value(source, count):
	"""Read a `count`-byte big-endian unsigned integer from `source`."""
	total = 0
	for _ in range(count):
		total = total * 256 + byte(source)
	return total


def as_little_endian(value, count):
	"""Return the low `count` bytes of integer `value` as a string, least significant first."""
	return ''.join(chr((value >> shift) & 0xFF) for shift in range(0, 8 * count, 8))


def read_info(status, next_byte, source):
	"""Read the data belonging to one MIDI event.

	status -- the event's status byte (after running-status resolution).
	next_byte -- the first byte after the status byte, already consumed.
	source -- stream positioned just after next_byte.

	Returns:
	- meta events (0xFF): (type, length) + data bytes, where next_byte is the
	  meta type;
	- SysEx events (0xF0, and 0xF7 escapes): () after skipping the payload;
	- channel messages: a 1-tuple for program change / channel pressure
	  (0xC_, 0xD_), otherwise a 2-tuple.
	Raises PEBKAC for any other system status byte or a truncated SysEx.
	"""
	if status == 0xFF:
		# The meta-event length is a variable-length quantity, not one byte.
		command_length = encoded_value(source)
		command_data = tuple(byte(source) for _ in range(command_length))
		return (next_byte, command_length) + command_data

	if status in (0xF0, 0xF7):
		# next_byte starts the variable-length payload length; skip by length
		# rather than scanning for 0xF7, which breaks on split SysEx packets.
		length = next_byte & 0x7F
		while next_byte & 0x80:
			next_byte = byte(source)
			length = (length << 7) | (next_byte & 0x7F)
		if len(source.read(length)) != length:
			raise PEBKAC("MIDI format error (truncated SysEx message)")
		return ()

	status_type = status >> 4
	if status_type == 0xF:
		raise PEBKAC("Unexpected status byte: %2x" % status)

	assert 0x8 <= status_type <= 0xE

	if status_type in (0xC, 0xD): return (next_byte,)
	return (next_byte, byte(source))


def read_event(source, previous_status):
	"""Read one (delta time, event) pair from a track stream.

	previous_status -- status byte of the previous event, for running status
		(None at the start of a track).
	Returns (delay, status, info), where info comes from read_info().
	Raises PEBKAC if running status is used with no usable previous status.
	"""
	delay = encoded_value(source)
	next_byte = byte(source)
	if next_byte < 0x80:
		# Running status: next_byte is already data. Meta and SysEx events
		# cancel running status, so only a channel status can be reused.
		if previous_status is None or previous_status >= 0xF0:
			raise PEBKAC("MIDI format error (running status without a channel status)")
		status = previous_status
	else:
		status = next_byte
		next_byte = byte(source)
	info = read_info(status, next_byte, source)
	return delay, status, info


def parse(midi, begin, end):
	"""Parse a Standard MIDI File into a MIDI_data object.

	midi -- file-like object positioned at the start of the MIDI data.
	begin, end -- clip range in GBA ticks, passed through to MIDI_data.
	Raises PEBKAC when the file is malformed or uses SMPTE timing.
	"""
	# N.B. All numbers are big-endian in MIDI! This is the opposite of GBA stuff.
	# Magic signature - MThd + header length count, which should always be 6.
	expect(midi, 'MThd\x00\x00\x00\x06', 'bad MIDI header')
	file_type = value(midi, 2)
	track_count = value(midi, 2)
	# Validate with exceptions rather than asserts, which vanish under -O.
	if file_type == 0 and track_count != 1:
		raise PEBKAC("MIDI format error (type 0 file must contain exactly one track)")
	division = value(midi, 2)
	if division & 0x8000:
		# Top bit set means SMPTE frame timing, not ticks per quarter note.
		raise PEBKAC("MIDI format error (SMPTE time division is not supported)")
	if division == 0:
		raise PEBKAC("MIDI format error (zero time division)")
	result = MIDI_data(division, begin, end)
	for track in range(track_count):
		expect(midi, 'MTrk', 'bad track header')
		length = value(midi, 4)
		data = StringIO(midi.read(length))
		status = None
		time = 0
		done = False
		while not done and data.tell() != length:
			delay, status, info = read_event(data, status)
			time += delay
			done = result.add(time, status >> 4, status & 0xF, info)

		# The end-of-track marker and the end of the chunk must coincide.
		if not done or data.tell() != length:
			raise PEBKAC("MIDI format error (track didn't end properly)")

	return result


def do_conversion(midi, instrument_map_offset, loop_ticks, begin, end):
	"""Convert raw MIDI file contents into a GBA song chunk.

	midi -- the MIDI file contents as a string.
	instrument_map_offset -- location of an existing instrument map, passed
		through as_pointer() into the song header.
	loop_ticks, begin, end -- forwarded to parse()/MIDI_data.converted().
	Returns (Chunk, list of warning strings).
	"""
	raw = parse(StringIO(midi), begin, end)
	status = raw.warnings()
	converted = raw.converted(loop_ticks)
	# Order tracks by MIDI channel so the output is deterministic and the
	# track carrying tempo changes (the lowest channel) comes first.
	tracks = [converted[channel] for channel in sorted(converted)]

	# Add metadata to make it all look like a ripped song.

	# Track count and instrument count. Since we are going to use
	# an existing map, there are no local instruments, so we write 0.
	parts = [chr(len(tracks)), '\x00\x00\x00']
	parts.append(as_little_endian(as_pointer(instrument_map_offset), 4))
	# Track offsets: 4 header bytes + 4-byte map pointer + 4 bytes per track.
	offset = 4 * (2 + len(tracks))
	for track in tracks:
		parts.append(as_little_endian(offset, 4))
		offset += len(track)
	parts.extend(tracks)

	return Chunk(''.join(parts)), status
